import torch
import numpy as np
import random
import qm9.utils as qm9utils

from utils import utils as diffusion_utils
from dadm.egnn_diffusion import EnVariationalDiffusion
from dadm.egnn_vae import EnHierarchicalVAE

from utils.utils import assert_mean_zero_with_mask, remove_mean_with_mask, \
    assert_correctly_masked, sample_center_gravity_zero_gaussian_with_mask


def disabled_train(self, mode=True):
    """No-op stand-in for ``model.train``.

    Assign this over a module's ``train`` attribute to stop later
    train/eval switches from having any effect on it. ``mode`` is accepted
    for signature compatibility and ignored; the first argument is handed
    back untouched.
    """
    return self


class PropertyClassifier(torch.nn.Module):
    """Predict one scalar per graph from per-node features.

    A node-wise MLP is applied first, masked node outputs are summed per
    graph, and a second MLP maps the pooled vector to a single value.
    """

    def __init__(self, hidden_nf, act_fn=torch.nn.SiLU(), n_layers=4):
        """Build the node-level and graph-level decoders.

        Args:
            hidden_nf: Width of the incoming node features and hidden layers.
            act_fn: Activation module placed between the two linear layers of
                each decoder (the same instance is used in both).
            n_layers: Stored on the instance; not used when building layers.
        """
        super().__init__()
        self.hidden_nf = hidden_nf
        self.n_layers = n_layers

        width = self.hidden_nf

        # Per-node transform: width -> width -> width.
        node_layers = [torch.nn.Linear(width, width), act_fn, torch.nn.Linear(width, width)]
        self.node_dec = torch.nn.Sequential(*node_layers)

        # Pooled-graph readout: width -> width -> 1.
        graph_layers = [torch.nn.Linear(width, width), act_fn, torch.nn.Linear(width, 1)]
        self.graph_dec = torch.nn.Sequential(*graph_layers)

    def forward(self, h, node_mask, n_nodes):
        """Return one prediction per graph.

        Args:
            h: Node features whose last dimension is ``hidden_nf``; after
                masking they are reshaped to ``(-1, n_nodes, hidden_nf)``.
            node_mask: Mask multiplied element-wise with the decoded node
                features (must broadcast against them).
            n_nodes: Number of node slots per graph used for the reshape.

        Returns:
            Tensor with the trailing size-1 prediction dimension squeezed out.
        """
        node_feats = self.node_dec(h) * node_mask
        pooled = node_feats.view(-1, n_nodes, self.hidden_nf).sum(dim=1)
        return self.graph_dec(pooled).squeeze(1)


class DomainAdaptiveDiffusion(EnVariationalDiffusion):
    """
    The E(n) Latent Diffusion Module.
    """

    def __init__(self, **kwargs):
        """Set up the diffusion model and attach the first-stage VAE.

        Keyword Args:
            vae: Required. The autoencoder used as first-stage model; removed
                from ``kwargs`` before they are forwarded to the parent.
            trainable_ae: Optional, defaults to False. Whether the VAE is
                kept trainable (see ``instantiate_first_stage``).
            **kwargs: Everything else is passed to the parent constructor.
        """
        first_stage = kwargs.pop('vae')
        ae_trainable = kwargs.pop('trainable_ae', False)
        super().__init__(**kwargs)

        # Register the first-stage model as self.vae.
        self.trainable_ae = ae_trainable
        self.instantiate_first_stage(first_stage)
        # MARK difference
        # self.property_classifier = PropertyClassifier(hidden_nf=vae.latent_node_nf)
        # self.loss_l1 = torch.nn.L1Loss()

    def unnormalize_z(self, z, node_mask):
        """Return the x / categorical / integer channels of ``z`` without rescaling.

        Overrides the parent hook so that it performs no unnormalization
        (used by ``sample_chain``). ``z`` is sliced along its last dimension
        into ``n_dims`` coordinate channels, ``num_classes`` categorical
        channels and at most one integer channel, and those slices are
        concatenated back together; any channels beyond that are dropped.

        Args:
            z: Tensor indexed as ``z[:, :, channel]``.
            node_mask: Unused here; kept for interface compatibility.

        Returns:
            The concatenation of the three slices along dimension 2.
        """
        cat_start = self.n_dims
        int_start = cat_start + self.num_classes

        coords = z[:, :, :cat_start]
        feat_categorical = z[:, :, cat_start:int_start]
        feat_integer = z[:, :, int_start:int_start + 1]
        assert feat_integer.size(2) == self.include_charges

        # The call to self.unnormalize(...) is deliberately skipped here.
        return torch.cat([coords, feat_categorical, feat_integer], dim=2)

    def log_constants_p_h_given_z0(self, h, node_mask):
        """Compute the log normalization constants of p(h | z0).

        Per sample this evaluates
        ``n_nodes * n_dims * (-log_sigma_x - 0.5 * log(2 * pi))`` with
        ``log_sigma_x = 0.5 * gamma(0)``.

        Args:
            h: Tensor whose first dimension is the batch size; also supplies
                the device for the time-zero input to ``self.gamma``.
            node_mask: Mask that is squeezed on dim 2 and summed over dim 1
                to count the nodes of every sample.

        Returns:
            Tensor of shape ``[batch_size]``.
        """
        batch_size = h.size(0)

        # Node count of every sample, shape [B].
        nodes_per_sample = node_mask.squeeze(2).sum(1)
        assert nodes_per_sample.size() == (batch_size,)
        dof = nodes_per_sample * self.n_dims

        t_zero = torch.zeros((batch_size, 1), device=h.device)
        gamma_0 = self.gamma(t_zero)

        # sigma_x = sqrt(sigma_0^2 / alpha_0^2) = SNR(-0.5 gamma_0), hence
        # log sigma_x = 0.5 * gamma_0.
        log_sigma_x = 0.5 * gamma_0.view(batch_size)
        log_norm_per_dof = - log_sigma_x - 0.5 * np.log(2 * np.pi)

        return dof * log_norm_per_dof

    # def sample_p_xh_given_z0(self, z0, node_mask, edge_mask, context, fix_noise=False):
    #     """Samples x ~ p(x|z0)."""
    #     zeros = torch.zeros(size=(z0.size(0), 1), device=z0.device)
    #     gamma_0 = self.gamma(zeros)
    #     # Computes sqrt(sigma_0^2 / alpha_0^2)
    #     sigma_x = self.SNR(-0.5 * gamma_0).unsqueeze(1)
    #     net_out = self.phi(z0, zeros, node_mask, edge_mask, context)
    #
    #     # Compute mu for p(zs | zt).
    #     mu_x = self.compute_x_pred(net_out, z0, gamma_0)
    #     xh = self.sample_normal(mu=mu_x, sigma=sigma_x, node_mask=node_mask, fix_noise=fix_noise)
    #
    #     x = xh[:, :, :self.n_dims]
    #
    #     # h_int = z0[:, :, -1:] if self.include_charges else torch.zeros(0).to(z0.device)
    #     # x, h_cat, h_int = self.unnormalize(x, z0[:, :, self.n_dims:-1], h_int, node_mask)
    #
    #     # h_cat = F.one_hot(torch.argmax(h_cat, dim=2), self.num_classes) * node_mask
    #     # h_int = torch.round(h_int).long() * node_mask
    #
    #     # Make the mol_data structure compatible with the EnVariationalDiffusion sample() and sample_chain().
    #     h = {'integer': xh[:, :, self.n_dims:], 'categorical': torch.zeros(0).to(xh)}
    #
    #     return x, h

    # def log_pxh_given_z0_without_constants(
    #         self, x, h, z_t, gamma_0, eps, net_out, node_mask, epsilon=1e-10):
    #
    #     # Computes the error for the distribution N(latent | 1 / alpha_0 z_0 + sigma_0/alpha_0 eps_0, sigma_0 / alpha_0),
    #     # the weighting in the epsilon parametrization is exactly '1'.
    #     log_pxh_given_z_without_constants = -0.5 * self.compute_error(net_out, gamma_0, eps)
    #
    #     # Combine log probabilities for x and h.
    #     log_p_xh_given_z = log_pxh_given_z_without_constants
    #
    #     return log_p_xh_given_z

    def vae_self_condition(self, x, h, node_mask=None, edge_mask=None, context=None, scaffold_mask=None):
        """Encode the (optionally scaffold-masked) input with the VAE and build the self-condition.

        Args:
            x: Coordinates, indexed as ``[batch, node, dim]``.
            h: Dict with ``'categorical'`` and ``'integer'`` node features.
            node_mask: Node mask; unpacked as ``(bs, ns, _)`` when a scaffold
                mask is given.
            edge_mask: Edge mask; viewed as ``(bs, ns, ns)`` when a scaffold
                mask is given.
            context: Optional conditioning tensor passed to the encoder.
            scaffold_mask: Optional ``[batch, node]`` mask selecting the nodes
                that are shown to the encoder. Rows that are entirely zero are
                overwritten with ones **in place**, i.e. the caller's tensor
                is modified. When ``None`` the input is encoded unmasked.

        Returns:
            Tuple ``(self_condition, loss_recon)``:

            * ``self_condition`` is the latent channel at index ``n_dims``
              (after multiplying by ``node_mask``) when
              ``self.dynamics.self_condition_nf == 1``; otherwise, when that
              setting is positive, latent channels ``[0, self_condition_nf)``;
              otherwise ``None``.
            * ``loss_recon`` is the VAE reconstruction error when
              ``self.trainable_ae`` is set, else ``0``.
        """
        # Mask before AE
        if scaffold_mask is not None:
            # A scaffold mask that is all zeros for a sample means "no mask" for it.
            # NOTE: this assignment mutates the tensor owned by the caller.
            mask = torch.all(scaffold_mask == 0, dim=-1)
            scaffold_mask[mask] = 1
            bs, ns, _ = node_mask.size()
            scaffold_mask_expanded = scaffold_mask.unsqueeze(-1)
            masked_x = x * scaffold_mask_expanded
            masked_h = {'categorical': h['categorical'] * scaffold_mask_expanded,
                        'integer': h['integer'] * scaffold_mask_expanded}
            masked_node_mask = node_mask * scaffold_mask_expanded
            # Keep an edge only when both of its endpoints are inside the scaffold.
            masked_edge_mask = edge_mask.view(bs, ns, ns) * torch.bmm(scaffold_mask_expanded,
                                                                      scaffold_mask.unsqueeze(1))
            masked_edge_mask = masked_edge_mask.view(bs, ns * ns)
            # Re-center the coordinates that remain after masking.
            masked_x = remove_mean_with_mask(masked_x, masked_node_mask)
            if context is not None:
                masked_context = context * scaffold_mask_expanded
            else:
                masked_context = context
            # Encode mol_data to latent space.
            z_x_mu, z_x_sigma, z_h_mu, z_h_sigma = self.vae.encode(masked_x, masked_h, masked_node_mask,
                                                                   masked_edge_mask, masked_context)
        else:
            masked_node_mask = node_mask
            masked_edge_mask = edge_mask
            masked_context = context
            z_x_mu, z_x_sigma, z_h_mu, z_h_sigma = self.vae.encode(x, h, masked_node_mask, masked_edge_mask,
                                                                   masked_context)
        # Compute fixed sigma values.
        t_zeros = torch.zeros(size=(x.size(0), 1), device=x.device)
        gamma_0 = self.inflate_batch_array(self.gamma(t_zeros), x)
        sigma_0 = self.sigma(gamma_0, x)

        # Infer latent z. The sampling std is the fixed sigma_0, not the
        # encoder's predicted z_x_sigma / z_h_sigma.
        z_xh_mean = torch.cat([z_x_mu, z_h_mu], dim=2)
        diffusion_utils.assert_correctly_masked(z_xh_mean, masked_node_mask)
        z_xh_sigma = sigma_0
        z_xh = self.vae.sample_normal(z_xh_mean, z_xh_sigma, masked_node_mask)
        z_xh = z_xh.detach()  # Always keep the encoder fixed.
        diffusion_utils.assert_correctly_masked(z_xh, masked_node_mask)

        # Compute reconstruction loss.
        if self.trainable_ae:
            xh = torch.cat([x, h['categorical'], h['integer']], dim=2)
            # Decoder output (reconstruction).
            x_recon, h_recon = self.vae.decoder._forward(z_xh, masked_node_mask, masked_edge_mask, masked_context)
            xh_rec = torch.cat([x_recon, h_recon], dim=2)
            loss_recon = self.vae.compute_reconstruction_error(xh_rec, xh)
        else:
            loss_recon = 0

        z_x = z_xh[:, :, :self.n_dims]
        # Without a scaffold mask there is nothing to unsqueeze; fall back to
        # the node mask that was used for encoding. (Previously this line
        # raised AttributeError for scaffold_mask=None.)
        if scaffold_mask is not None:
            mean_zero_mask = scaffold_mask.unsqueeze(-1)
        else:
            mean_zero_mask = masked_node_mask
        diffusion_utils.assert_mean_zero_with_mask(z_x, mean_zero_mask)
        if self.dynamics.self_condition_nf == 1:
            # Single-channel condition: first non-coordinate latent channel.
            z_xh = z_xh * node_mask
            self_condition = z_xh[:, :, self.n_dims:self.n_dims + 1]
        else:
            self_condition = torch.cat((z_x, z_xh[:, :, self.n_dims:self.dynamics.self_condition_nf]),
                                       dim=2) if self.dynamics.self_condition_nf > 0 else None
        return self_condition, loss_recon

    def forward(self, x, h, node_mask=None, edge_mask=None, context=None, scaffold_mask=None):
        """Compute the negative log-likelihood style loss for a batch.

        Computes the loss (type l2 or NLL) if training. And if eval then
        always computes NLL.

        Args:
            x: Coordinates.
            h: Dict with ``'categorical'`` and ``'integer'`` node features.
            node_mask: Node mask.
            edge_mask: Edge mask.
            context: Optional conditioning tensor. When self-conditioning is
                enabled the VAE latent is concatenated to it along dim 2 (or
                used on its own if ``context`` is ``None``).
            scaffold_mask: Optional scaffold mask forwarded to
                ``vae_self_condition``.

        Returns:
            ``loss_ld + loss_recon + neg_log_constants`` where the constants
            are zeroed during training with ``loss_type == 'l2'``.
        """
        # Build the self condition from the un-normalized inputs and merge it
        # into the context.
        if self.dynamics.self_condition_nf > 0:
            z_xh, loss_recon = self.vae_self_condition(x, h, node_mask, edge_mask, context, scaffold_mask)
            if z_xh is not None:
                context = z_xh if context is None else torch.concat([context, z_xh], dim=2)
        else:
            loss_recon = 0

        # Normalize mol_data, take into account volume change in x.
        x, h, delta_log_px = self.normalize(x, h, node_mask)

        # Training: a single forward pass (t0_always=False).
        # Eval: lower-variance estimator at the cost of two passes (t0_always=True).
        loss_ld, _ = self.compute_loss(x, h, node_mask, edge_mask, context, t0_always=not self.training)

        # The _constants_ depending on sigma_0 from the
        # cross entropy term E_q(z0 | x) [log p(x | z0)].
        neg_log_constants = -self.log_constants_p_h_given_z0(
            torch.cat([h['categorical'], h['integer']], dim=2), node_mask)
        # Reset constants during training with l2 loss.
        if self.training and self.loss_type == 'l2':
            neg_log_constants = torch.zeros_like(neg_log_constants)

        return loss_ld + loss_recon + neg_log_constants

    @torch.no_grad()
    def process_target_data(self, args, data, device, is_context):
        """Turn a target-data batch into the self-condition latent.

        Args:
            args: Options object; ``args.mask_ratio`` is read only when the
                batch carries no scaffold mask and a random one is drawn.
            data: Mapping with ``'positions'``, ``'atom_mask'``,
                ``'edge_mask'``, ``'one_hot'``, ``'charges'`` (read only if
                ``self.include_charges``) and optionally ``'scaffold_mask'``,
                ``'num_atoms'`` and ``'context'``.
            device: Device the tensors are moved to.
            is_context: Whether to read ``data['context']``.

        Returns:
            Tuple of five values
            ``(z_h, node_mask, edge_mask, context, scaffold_mask)``.
        """
        dtype = torch.float32
        x = data['positions'].to(device, dtype)
        node_mask = data['atom_mask'].to(device, dtype).unsqueeze(2)
        edge_mask = data['edge_mask'].to(device, dtype)
        one_hot = data['one_hot'].to(device, dtype)
        charges = (data['charges'] if self.include_charges else torch.zeros(0)).to(device, dtype)

        # Only a missing (or None) scaffold mask triggers the random fallback;
        # any other failure (e.g. a device error) is allowed to propagate
        # instead of being silently replaced by a random mask.
        try:
            scaffold_mask = data['scaffold_mask']
        except KeyError:
            scaffold_mask = None

        if scaffold_mask is not None:
            scaffold_mask = scaffold_mask.to(device, dtype)
        else:
            bs, ns, _ = node_mask.size()
            scaffold_mask = torch.zeros([bs, ns]).to(device, dtype)
            try:
                masked_length = data['num_atoms']
            except KeyError:
                # Count the non-zero entries of every atom-mask row instead.
                tensor = data['atom_mask'].to(dtype)
                masked_length = [torch.nonzero(row).size(0) for row in tensor]
            for i in range(bs):
                # Keep a random (1 - mask_ratio) fraction of the first
                # masked_length[i] node slots.
                ones_indices = random.sample(range(masked_length[i]), int((1 - args.mask_ratio) * masked_length[i]))
                scaffold_mask[i, ones_indices] = 1

        x = remove_mean_with_mask(x, node_mask)
        diffusion_utils.assert_mean_zero_with_mask(x, node_mask)
        h = {'categorical': one_hot, 'integer': charges}
        # NOTE(review): unlike the other tensors, the context is used as stored
        # in the batch (no device/dtype conversion) - confirm this is intended.
        context = data['context'] if is_context else None
        z_h, _ = self.vae_self_condition(x, h, node_mask, edge_mask, context, scaffold_mask)
        return z_h, node_mask, edge_mask, context, scaffold_mask

    @torch.no_grad()
    def process_pas_target_data(self, data, device, is_context):
        """Turn a positionally-indexed target batch into the self-condition latent.

        Args:
            data: Container read as ``data[0]`` (positions), ``data[1]``
                (atom mask), ``data[2]`` (edge mask) and ``data[3]``
                (one-hot atom types).
            device: Device the tensors are moved to.
            is_context: Whether to read ``data['context']``.

        Returns:
            Tuple ``(z_h, node_mask, edge_mask, context)``. No scaffold mask
            is used on this path and no integer (charge) features are
            supplied (an empty tensor is passed).
        """
        dtype = torch.float32
        x = data[0].to(device, dtype)
        node_mask = data[1].to(device, dtype).unsqueeze(2)
        edge_mask = data[2].to(device, dtype)
        # NOTE(review): one-hot is cast to int64 here while process_target_data
        # uses float32 - confirm the encoder accepts integer features.
        one_hot = data[3].to(device, torch.int64)

        # Center the coordinates once; the original applied the same centering
        # twice in a row, which is redundant.
        x = remove_mean_with_mask(x, node_mask)
        diffusion_utils.assert_mean_zero_with_mask(x, node_mask)

        charges = torch.zeros(0).to(x.device)
        h = {'categorical': one_hot, 'integer': charges}
        if is_context:
            # NOTE(review): data is indexed by position above but by the key
            # 'context' here; for a list/tuple batch this lookup fails -
            # verify against the caller.
            context = data['context']
        else:
            context = None
        z_h, _ = self.vae_self_condition(x, h, node_mask, edge_mask, context)
        return z_h, node_mask, edge_mask, context

    @torch.no_grad()
    def sample(self, n_samples, n_nodes, target_data, node_mask, edge_mask, context, fix_noise=False, args=None):
        """Draw samples from the generative model, conditioned on target data.

        Args:
            n_samples: Number of samples, forwarded to the parent sampler.
            n_nodes: Ignored; it is overwritten with ``node_mask.size(1)``.
            target_data: Batch from which the self-condition latent is built
                (only used when ``self.dynamics.self_condition_nf > 0``).
            node_mask: Node mask; replaced by the processed one on the
                ``args.pas_data`` path.
            edge_mask: Edge mask; replaced likewise on the pas path.
            context: Optional conditioning tensor.
            fix_noise: Forwarded to the parent sampler.
            args: Options object; ``args.pas_data`` selects the processing
                routine and ``args`` is forwarded to ``process_target_data``.

        Returns:
            Tuple ``(x, h, node_mask)`` with ``x`` and ``h`` exactly as
            returned by the parent ``sample`` (no VAE decoding is applied).
        """
        is_context = context is not None
        if self.dynamics.self_condition_nf > 0:
            if args is not None and args.pas_data:
                self_condition, node_mask, edge_mask, context = self.process_pas_target_data(
                    target_data, node_mask.device, is_context)
            else:
                # process_target_data returns five values regardless of
                # self_condition_nf; only the latent is needed here. (The
                # former 4-way unpack raised ValueError for
                # self_condition_nf == 1.)
                self_condition = self.process_target_data(args, target_data, node_mask.device, is_context)[0]
        else:
            self_condition = None

        n_nodes = node_mask.size(1)
        # Zero-pad the self condition along the node axis up to n_nodes.
        if self_condition is not None and self_condition.size(1) < n_nodes:
            diff = n_nodes - self_condition.size(1)
            zeros = torch.zeros((self_condition.size(0), diff, self_condition.size(2)), device=self_condition.device)
            self_condition = torch.concat([self_condition, zeros], 1)

        if context is not None and self_condition is not None:
            context = context[:, :n_nodes, :] * node_mask
            context = torch.concat([context, self_condition], 2)
        elif self_condition is not None:
            context = self_condition

        x, h = super().sample(n_samples, n_nodes, node_mask, edge_mask, context, fix_noise)

        return x, h, node_mask

    @torch.no_grad()
    def sample_chain(self, n_samples, n_nodes, target_data, node_mask, edge_mask, context, keep_frames=None,
                     args=None):
        """Draw samples and keep the intermediate states for visualization.

        Unlike ``sample``, the target data is processed here regardless of
        ``self.dynamics.self_condition_nf``.

        Args:
            n_samples: Number of samples, forwarded to the parent sampler.
            n_nodes: Ignored; it is overwritten with ``node_mask.size(1)``.
            target_data: Batch from which the self-condition latent is built.
            node_mask: Node mask; replaced by the processed one on the
                ``args.pas_data`` path.
            edge_mask: Edge mask; replaced likewise on the pas path.
            context: Optional conditioning tensor.
            keep_frames: Forwarded to the parent ``sample_chain``.
            args: Options object; ``args.pas_data`` selects the processing
                routine and ``args`` is forwarded to ``process_target_data``.

        Returns:
            The chain exactly as returned by the parent ``sample_chain``
            (no VAE decoding is applied to the frames).
        """
        is_context = context is not None
        if args is not None and args.pas_data:
            self_condition, node_mask, edge_mask, context = self.process_pas_target_data(
                target_data, node_mask.device, is_context)
        else:
            # process_target_data returns five values regardless of
            # self_condition_nf; only the latent is needed here. (The former
            # 4-way unpack raised ValueError for self_condition_nf == 1.)
            self_condition = self.process_target_data(args, target_data, node_mask.device, is_context)[0]

        n_nodes = node_mask.size(1)
        # Zero-pad the self condition along the node axis up to n_nodes.
        if self_condition is not None and self_condition.size(1) < n_nodes:
            diff = n_nodes - self_condition.size(1)
            zeros = torch.zeros((self_condition.size(0), diff, self_condition.size(2)), device=self_condition.device)
            self_condition = torch.concat([self_condition, zeros], 1)

        if context is not None and self_condition is not None:
            context = context[:, :n_nodes, :] * node_mask
            context = torch.concat([context, self_condition], 2)
        elif self_condition is not None:
            context = self_condition

        chain_flat = super().sample_chain(n_samples, n_nodes, node_mask, edge_mask, context, keep_frames)

        return chain_flat

    def instantiate_first_stage(self, vae: EnHierarchicalVAE):
        """Attach ``vae`` as ``self.vae`` in either frozen or trainable form.

        When ``self.trainable_ae`` is false the VAE is put in eval mode, its
        ``train`` attribute is replaced by ``disabled_train`` so later mode
        switches do not affect it, and every parameter has ``requires_grad``
        turned off. Otherwise the VAE is put in train mode and every
        parameter has ``requires_grad`` turned on.
        """
        trainable = bool(self.trainable_ae)

        if trainable:
            self.vae = vae.train()
        else:
            self.vae = vae.eval()
            # Pin the module in eval mode.
            self.vae.train = disabled_train

        for weight in self.vae.parameters():
            weight.requires_grad = trainable